class TreeNode:
    """Binary tree node: a value plus ``left`` and ``right`` child links."""

    def __init__(self, val: int):
        # Both child links start out as None.
        self.left = self.right = None
        self.val = val


def lowest_common_ancestor(
    root: "TreeNode | None", p: "TreeNode", q: "TreeNode"
) -> "TreeNode | None":
    """Return the lowest common ancestor of nodes *p* and *q* under *root*.

    Nodes are matched by identity (``is``), not by value, so trees holding
    duplicate values are handled correctly. A node counts as its own
    ancestor, so if *p* is an ancestor of *q* the result is *p*.

    The traversal is iterative, so arbitrarily deep (e.g. degenerate,
    list-shaped) trees do not hit the interpreter's recursion limit.

    Args:
        root: Root of the tree to search; may be ``None`` for an empty tree.
        p: First target node.
        q: Second target node.

    Returns:
        The deepest node that has both *p* and *q* in its subtree. If only
        one of the two is reachable from *root*, that node is returned; if
        neither is reachable (or *root* is ``None``), ``None`` is returned.
    """
    if root is None:
        return None

    # Parent links are keyed by id() so the node type need not be hashable.
    parent_of = {id(root): None}
    found_p = found_q = False
    pending = [root]

    # Depth-first walk with an explicit stack; stop once both are located.
    while pending and not (found_p and found_q):
        node = pending.pop()
        if node is p:
            found_p = True
        if node is q:
            found_q = True
        for child in (node.left, node.right):
            if child is not None:
                parent_of[id(child)] = node
                pending.append(child)

    if not (found_p and found_q):
        # At most one target is in the tree: report whichever was found.
        if found_p:
            return p
        if found_q:
            return q
        return None

    # Record every ancestor of p (p included), then climb from q until the
    # two chains meet. The root is on both chains, so the climb terminates.
    p_chain = set()
    node = p
    while node is not None:
        p_chain.add(id(node))
        node = parent_of[id(node)]

    node = q
    while id(node) not in p_chain:
        node = parent_of[id(node)]
    return node
